import csv
import pdb
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap
import seaborn as sns

# Per-row call table (tab-separated) and per-row FORMAT values (colon-separated).
# The two files are expected to describe the same rows in the same order.
arr = np.genfromtxt("hg00733calls.txt", delimiter='\t', dtype=str)
genotypes = np.genfromtxt("hg00733_genotypes.txt", delimiter=':', dtype=str)

if len(genotypes) != len(arr):
    # Quality values (from `genotypes`) are plotted against labels derived
    # from `arr`, so a length mismatch would silently misalign them.
    raise ValueError(
        f"row count mismatch: {len(arr)} calls vs {len(genotypes)} genotype rows"
    )

#GT:NDQ:DQ:EQ:SQ:NQ:LQ:RQ:PL:RD:ORD:DSCVR

xhmmcalls = arr[:, 15]
deccalls = arr[:, 16]
length = arr[:, 3]
gts = arr[:, -1]  # compared against both callers' calls when labelling corrections


def _parse_quality(field):
    """Return the largest numeric value in a comma-separated FORMAT field.

    Fields such as SQ/NQ presumably carry one value per ALT allele with 0 for
    the allele that was not called (e.g. "0,45") -- TODO confirm. Taking the
    maximum extracts the non-zero score without corrupting values that merely
    contain a '0' digit (the previous digit-stripping turned "100" into 1 and
    "10,0" into 1). Single-valued fields (e.g. NDQ) are returned unchanged.

    Raises:
        ValueError: if a component is not a number (e.g. a missing-value '.').
    """
    return max(float(part) for part in field.split(','))


# Column positions follow the FORMAT order above.
SQs = np.asarray([_parse_quality(x) for x in genotypes[:, 4]])
NQs = np.asarray([_parse_quality(x) for x in genotypes[:, 5]])
NDQs = np.asarray([_parse_quality(x) for x in genotypes[:, 1]])



# Element-wise masks: rows where DECoNT's call differs from / matches XHMM's.
correctedids = xhmmcalls != deccalls
notcorrectedids = xhmmcalls == deccalls

# Coarse label: DECoNT overrides an XHMM call that disagrees with the truth.
hueing = [
    'DECoNT Corrects' if (xhmm != truth and xhmm != dec) else 'DECoNT Agrees'
    for xhmm, dec, truth in zip(xhmmcalls, deccalls, gts)
]


def _correction_label(xhmm, dec, truth):
    """Classify one row for the scatter-plot hue.

    A row counts as a correction only when XHMM is wrong and DECoNT matches
    the truth; a matching 'NO-CALL' is a true-negative correction, any other
    matching call a true-positive one. Everything else is 'DECoNT Agrees'.
    """
    if xhmm == truth or dec != truth:
        return 'DECoNT Agrees'
    if dec == 'NO-CALL':
        return 'True Negative Correction'
    return 'True Positive Correction'


hueing2 = [
    _correction_label(xhmm, dec, truth)
    for xhmm, dec, truth in zip(xhmmcalls, deccalls, gts)
]





# Disabled legacy plot, kept for reference only. The triple-quoted string below
# is a no-op expression statement, so none of the code inside it executes.
# It scattered SQ against NQ, split by `correctedids` (DECoNT call differs from
# XHMM) versus `notcorrectedids`.
# NOTE(review): if re-enabled, the y-axis label says "NDQ value" although the
# y data is NQ, and the pdb.set_trace() calls would pause the script.
'''
SQscorrected = SQs[correctedids]
SQsnotcorrected = SQs[notcorrectedids]

NQscorrected = NQs[correctedids]
NQsnotcorrected = NQs[notcorrectedids]


plt.scatter(SQscorrected, NQscorrected, s=78, c='blue',marker="^",label='DECoNT corrected true XHMM calls', alpha=0.8)
plt.scatter(SQsnotcorrected, NQsnotcorrected, s=78, c='red',marker="^",label='True XHMM calls', alpha=0.8)

plt.xlabel("SQ value")
plt.ylabel("NDQ value")

plt.legend(loc='upper left')

plt.show()

pdb.set_trace()
plt.show()

pdb.set_trace()
'''
# Truncate NDQ to whole numbers before jittering (SQ/NQ stay as parsed).
NDQs = NDQs.astype(int)

# Jitter: add small Gaussian noise so rows with identical scores do not draw
# on top of each other. A fixed seed keeps the figure reproducible.
JITTER_SD = 2
_jitter_rng = np.random.default_rng(0)
SQs = SQs + _jitter_rng.normal(0, JITTER_SD, size=SQs.shape)
NQs = NQs + _jitter_rng.normal(0, JITTER_SD, size=NQs.shape)
NDQs = NDQs + _jitter_rng.normal(0, JITTER_SD, size=NDQs.shape)

plt.rcParams.update({'font.size': 8})

# Fixed label -> colour mapping. A plain list palette is matched to hue levels
# in order of first appearance, so colours could swap between input files.
# NOTE(review): this assignment of black/magenta is assumed — confirm intent.
palette = {
    'DECoNT Agrees': '#D3D3D3',
    'True Positive Correction': 'black',
    'True Negative Correction': 'magenta',
}
# Stable legend order, restricted to labels actually present in the data.
_present_labels = set(hueing2)
_hue_order = [label for label in palette if label in _present_labels]

sns.scatterplot(
    x=SQs,
    y=NQs,
    alpha=0.7,
    hue=hueing2,
    hue_order=_hue_order,
    palette=palette,
)

plt.xlabel("SQ", fontweight='normal', fontsize=10)
plt.ylabel("NQ", fontweight='normal', fontsize=10)

# tight_layout only adjusts axes that already exist, so run it after plotting.
plt.tight_layout()
plt.show()

